author:
- Tao Lei
- Yu Zhang
- Yoav Artzi
submission:
- ICLR 2018 Conference Blind Submission
year: "2018"
file:
- "[[01. Training_rnns_as_fast_as_cnns.pdf|04. Training_rnns_as_fast_as_cnns]]"
related:
tags:
- "#Recurrent-Learning"
review date: 2024-11-12
Summary
RNN의 병렬 처리 한계를 극복하기 위해 계산의 대부분을 독립적으로 처리할 수 있는 Simple Recurrent Unit(SRU)을 제안함 [Abstract, p.1]
RNN의 sequential dependency로 인한 병렬화 한계를 해결하기 위해 대부분의 계산을 독립적으로 수행할 수 있는 SRU 구조를 제안함 [Section 1 Introduction, p.1]
CUDA 레벨 최적화를 통해 기존 LSTM 대비 5-10배 빠른 처리 속도를 달성함 [Section 1 Introduction, Figure 2, p.1-2]
텍스트 분류, QA, 언어 모델링, 번역, 음성인식 등 다양한 태스크에서 좋은 성능을 보임 [Section 1 Introduction, p.1]
RNN의 sequential computation으로 인한 병렬화 한계를 지적 [Section 1 Introduction, p.1]
대부분의 계산을 time step 간 독립적으로 처리할 수 있는 SRU 구조를 제안 [Section 2.1, p.2]
Matrix multiplication을 batch 처리하고 element-wise 연산을 하나의 kernel function으로 융합하는 CUDA 레벨 최적화 구현 [Section 2.3, p.3]
LSTM 대비 5-10배 빠른 처리 속도로 CNN과 비슷한 수준의 속도를 달성 [Figure 2, Section 1 Introduction, p.2]
다양한 NLP/Speech 태스크에서 기존 모델과 비슷하거나 더 나은 성능을 보임 [Section 4 Experiments, p.4-9]
일반적인 RNN 구조는 state computation의 병렬화가 어려워 확장성이 떨어짐. 이를 해결하기 위해 대부분의 계산을 병렬화할 수 있는 SRU를 제안함 [Abstract, p.1]
RNN은 sequential dependency로 인해 각 step의 계산이 이전 step의 계산에 의존적임. 이로 인해 convolution이나 attention과 달리 병렬화가 어려움 [Section 1 Introduction, p.1]
RNN의 sequential computation으로 인한 병렬화의 어려움 [Section 1 Introduction, p.1]
이로 인한 느린 학습/추론 속도와 파라미터 튜닝의 어려움 [Section 1 Introduction, p.1]
대부분의 계산을 time step 간 독립적으로 수행할 수 있는 SRU 구조 설계 [Section 2.1, p.2]
The basic form of SRU includes a single forget gate. Given an input xt at time t, we compute a linear transformation x̃t and the forget gate ft:
x̃t = Wxt
ft = σ(Wf xt + bf)
This computation depends on xt only, which enables computing it in parallel across all time steps. The forget gate is used to modulate the internal state ct, which is used to compute the output state ht:
ct = ft ◦ ct-1 + (1 - ft) ◦ x̃t
ht = g(ct)
Matrix multiplication을 batch 처리하고 element-wise 연산을 하나의 kernel로 fusion하는 CUDA 최적화 [Section 2.3, Algorithm 1, p.3-4]
The SRU formulation permits two optimizations:
1. Matrix multiplications across all time steps can be batched:
U⊤ = [W Wf Wr][x1, x2, ..., xn]
where n is the sequence length, U ∈ Rn×3d is the resulting matrix
2. All element-wise operations of the sequence can be fused into one kernel function and parallelized across the dimensionality of the hidden state d. Without the fusion, operations such as addition + and sigmoid activation σ() would each invoke a separate function call.
Highway connection 추가로 deep network의 학습 안정성 향상 [Section 2.1, p.2]
Highway connection은 깊은 네트워크의 학습을 개선하는 것으로 알려진 skip connection을 포함
reset gate rt를 추가하여 내부 상태 g(ct)와 입력 xt의 조합으로 출력 상태 ht를 계산합니다. 이는 깊은 네트워크의 학습 안정성을 향상시킴
The complete architecture also includes skip connections, which have been shown to improve training of deep networks with a large number of layers. We use highway connections, and add a reset gate rt computed similar to the forget gate ft. The reset gate is used to compute the output state ht as a combination of the internal state g(ct) and the input xt:
rt = σ(Wr xt + br)
ht = rt ◦ g(ct) + (1 - rt) ◦ xt